#! /bin/bash
#
# Copyright (c) 2017, 2018 NVIDIA CORPORATION.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

# Abort on any unhandled command failure and on expansion of unset variables.
set -eu

# NVIDIA_VISIBLE_DEVICES set to the empty string *or* to "void":
# GPU support was explicitly disabled, exit early.
# (An unset variable expands to the placeholder and matches neither pattern.)
case "${NVIDIA_VISIBLE_DEVICES-unset}" in
    ""|void) exit 0;;
esac

# https://github.com/nvidia/nvidia-container-runtime#cuda_version
if [[ -n "${CUDA_VERSION:-}" && -z "${NVIDIA_REQUIRE_CUDA:-}" ]]; then
    # Legacy CUDA image: default to all devices and all driver capabilities.
    # Devices are defaulted only when the variable is unset, capabilities
    # when the variable is unset or empty.
    NVIDIA_VISIBLE_DEVICES="${NVIDIA_VISIBLE_DEVICES-all}"
    NVIDIA_DRIVER_CAPABILITIES="${NVIDIA_DRIVER_CAPABILITIES:-all}"
    # Derive a "cuda>=MAJOR.MINOR" requirement from the leading
    # MAJOR.MINOR part of CUDA_VERSION, when it has one.
    if [[ "${CUDA_VERSION}" =~ ^[0-9]+\.[0-9]+ ]]; then
        NVIDIA_REQUIRE_CUDA="cuda>=${BASH_REMATCH[0]}"
    fi
elif [[ -z "${NVIDIA_VISIBLE_DEVICES+x}" ]]; then
    # NVIDIA_VISIBLE_DEVICES unset and it's not a legacy CUDA image.
    # This is not a GPU image, exit early.
    exit 0
fi

# Append the usual system binary directories so the lookups below also work
# when the hook is started with a minimal PATH.
export PATH=$PATH:/usr/sbin:/usr/bin:/sbin:/bin

# "command -v" is a shell builtin. The external "which" utility is not
# guaranteed to be installed, and when it is absent the check would fail and
# wrongly report nvidia-container-cli as missing.
if ! command -v nvidia-container-cli >/dev/null; then
    echo "ERROR: Missing tool nvidia-container-cli, see https://github.com/NVIDIA/libnvidia-container" >&2
    exit 1
fi

# Classify the uid mapping of the current process.
# Outputs (stdout), exactly one of:
#   no          - /proc/self/uid_map does not exist, or it contains the line
#                 "0 0 4294967295" (identity mapping of the full 32-bit range)
#   userns-root - a mapping line has "0" as second field and "1" as third
#                 field, or the uid_map of this process is identical to the
#                 one of PID 1
#   yes         - none of the above
# Returns: 0
in_userns() {
    # Only the first three whitespace-separated fields of a line matter;
    # anything after them lands in "_". All loop variables are local so that
    # nothing leaks into the caller when the function is not run in a
    # command substitution.
    local inside outside count _

    if [ ! -e /proc/self/uid_map ]; then
        echo no
        return
    fi

    while read -r inside outside count _; do
        if [ "${inside} ${outside} ${count}" = "0 0 4294967295" ]; then
            echo no
            return
        fi
        if [ "${outside}" = "0" ] && [ "${count}" = "1" ]; then
            echo userns-root
            return
        fi
    done < /proc/self/uid_map

    if [ -e /proc/1/uid_map ]; then
        if [ "$(cat /proc/self/uid_map)" = "$(cat /proc/1/uid_map)" ]; then
            echo userns-root
            return
        fi
    fi
    echo yes
}

# Print the path of the ldconfig binary found in PATH.
# "ldconfig.real" is tried first, "ldconfig" is the fallback.
# Outputs: the resolved path on stdout (nothing when neither is found).
# Returns: 0 when one of the two names was found, non-zero otherwise.
get_ldconfig() {
    # "command -v" is a shell builtin, so the lookup does not depend on the
    # external "which" utility being installed.
    command -v "ldconfig.real" || command -v "ldconfig"
}

# Translate a driver capability name into the matching command-line flag.
# Arguments: $1 - capability name
# Outputs:   "--<name>" on stdout for a known capability
# An unknown name terminates the current (sub)shell with status 1, which is
# why callers run this function inside a command substitution.
capability_to_cli() {
    local known
    for known in compute compat32 display graphics utility video; do
        if [ "$1" = "${known}" ]; then
            echo "--${known}"
            return
        fi
    done
    exit 1
}

# Same behavior as strconv.ParseBool in golang.
# Arguments: $1 - textual boolean
# Outputs:   "true" or "false" on stdout for an accepted spelling
# Any other value terminates the current (sub)shell with status 1, so the
# function is meant to be called from a command substitution.
parse_bool() {
    local literal

    for literal in 1 t T TRUE true True; do
        if [[ "$1" == "${literal}" ]]; then
            echo "true"
            return
        fi
    done

    for literal in 0 f F FALSE false False; do
        if [[ "$1" == "${literal}" ]]; then
            echo "false"
            return
        fi
    done

    exit 1
}

# Print the help text of the hook on stdout, one line per argument.
# Returns: 0
usage() {
    printf '%s\n' \
        'nvidia-container-cli hook for LXC' \
        '' \
        'Special arguments:' \
        '[ -h | --help ]: Print this help message and exit.' \
        '' \
        'Optional arguments:' \
        '[ --no-load-kmods ]: Do not try to load the NVIDIA kernel modules.' \
        '[ --disable-require ]: Disable all the constraints of the form NVIDIA_REQUIRE_*.' \
        '[ --debug <path> ]: The path to the log file.' \
        "[ --ldcache <path> ]: The path to the host system's DSO cache." \
        '[ --root <path> ]: The path to the driver root directory.' \
        "[ --ldconfig <path> ]: The path to the ldconfig binary, use a '@' prefix for a host path."
    return 0
}

# Normalize the command line with getopt(1).
# The call is used as the "if" condition on purpose: the script runs under
# "set -e", where a failing plain assignment would terminate the script
# before the usage text could be printed.
if ! options=$(getopt -o h -l help,no-load-kmods,disable-require,debug:,ldcache:,root:,ldconfig: -- "$@"); then
    usage
    exit 1
fi
# Replace the positional parameters with getopt's quoted, reordered output.
eval set -- "$options"

# Defaults, overridden by the options handled below.
CLI_LOAD_KMODS="true"
CLI_DISABLE_REQUIRE="false"
CLI_DEBUG=
CLI_LDCACHE=
CLI_ROOT=
CLI_LDCONFIG=

# Consume the options; after the loop the positional parameters hold the
# non-option arguments that followed "--".
while :; do
    case "$1" in
        # getopt emits the short form as "-h", so both spellings are matched.
        # The exit status stays 1, as for the long option.
        -h|--help)          usage; exit 1;;
        --no-load-kmods)    CLI_LOAD_KMODS="false"; shift 1;;
        --disable-require)  CLI_DISABLE_REQUIRE="true"; shift 1;;
        --debug)            CLI_DEBUG=$2; shift 2;;
        --ldcache)          CLI_LDCACHE=$2; shift 2;;
        --root)             CLI_ROOT=$2; shift 2;;
        --ldconfig)         CLI_LDCONFIG=$2; shift 2;;
        --)                 shift 1; break;;
        *)                  break;;
    esac
done

# Determine the hook section and hook type.
# Hook version 0 (also assumed when LXC_HOOK_VERSION is unset or empty) reads
# them from the second and third positional arguments, version 1 reads them
# from the LXC_HOOK_SECTION and LXC_HOOK_TYPE environment variables.
HOOK_SECTION=
HOOK_TYPE=
if [ "${LXC_HOOK_VERSION:-0}" = "0" ]; then
    HOOK_SECTION="${2:-}"
    HOOK_TYPE="${3:-}"
elif [ "${LXC_HOOK_VERSION}" = "1" ]; then
    HOOK_SECTION="${LXC_HOOK_SECTION:-}"
    HOOK_TYPE="${LXC_HOOK_TYPE:-}"
else
    echo "ERROR: Unsupported hook version: ${LXC_HOOK_VERSION}." >&2
    exit 1
fi

# Only the "lxc" section is accepted.
[ "${HOOK_SECTION}" = "lxc" ] || {
    echo "ERROR: Not running through LXC." >&2
    exit 1
}

# Only the "mount" hook type is accepted.
[ "${HOOK_TYPE}" = "mount" ] || {
    echo "ERROR: This hook must be used as a \"mount\" hook." >&2
    exit 1
}

# Classify the uid mapping of this process; only the answer "yes" is
# supported, every other answer aborts the hook.
USERNS=$(in_userns)
case "${USERNS}" in
    yes)
        # Kernel modules are never loaded by the CLI in this mode, so only
        # warn when nvidia_uvm is not listed in /proc/modules.
        CLI_LOAD_KMODS="false"
        grep -q nvidia_uvm /proc/modules || \
            echo "WARN: Kernel module nvidia_uvm is not loaded, nvidia-container-cli might fail. Make sure the NVIDIA device driver is installed and loaded." >&2
        ;;
    *)
        # This is a limitation of libnvidia-container.
        echo "FIXME: This hook currently only works in unprivileged mode." >&2
        exit 1
        ;;
esac

# https://github.com/nvidia/nvidia-container-runtime#nvidia_disable_require
# A non-empty NVIDIA_DISABLE_REQUIRE that parses as "true" disables the
# requirement checks. A value that does not parse yields empty output from
# the command substitution and therefore changes nothing.
if [ -n "${NVIDIA_DISABLE_REQUIRE:-}" ] \
    && [ "$(parse_bool "${NVIDIA_DISABLE_REQUIRE}")" = "true" ]; then
    CLI_DISABLE_REQUIRE="true"
fi

# Without an explicit --debug path, write the nvidia-container-cli log next
# to the container rootfs when LXC runs at log level DEBUG or TRACE.
# LXC_LOG_LEVEL and LXC_ROOTFS_PATH are expanded with an empty default so
# that "set -u" does not abort the hook when one of them is not provided.
if [ -z "${CLI_DEBUG}" ]; then
    if [ "${LXC_LOG_LEVEL:-}" = "DEBUG" ] || [ "${LXC_LOG_LEVEL:-}" = "TRACE" ]; then
        # Drop everything up to and including the first ':' of the path.
        rootfs_path="${LXC_ROOTFS_PATH:-}"
        rootfs_path="${rootfs_path#*:}"
        # Replace a trailing "rootfs" component with "hook".
        hookdir="${rootfs_path/%rootfs/hook}"
        # An empty path cannot be created; leave logging disabled then.
        if [ -n "${hookdir}" ] && mkdir -p "${hookdir}"; then
            CLI_DEBUG="${hookdir}/nvidia.log"
        fi
    fi
fi

# Without an explicit --ldconfig path, fall back to the ldconfig found on the
# host. A '@' prefix means a host path.
if [ -z "${CLI_LDCONFIG}" ] && host_ldconfig=$(get_ldconfig); then
    CLI_LDCONFIG="@${host_ldconfig}"
fi

# https://github.com/nvidia/nvidia-container-runtime#nvidia_visible_devices
CLI_DEVICES="${NVIDIA_VISIBLE_DEVICES}"

# https://github.com/nvidia/nvidia-container-runtime#nvidia_driver_capabilities
# Unset or empty selects "utility" only, the single word "all" selects every
# known capability, anything else is a comma-separated list that is turned
# into a space-separated one.
case "${NVIDIA_DRIVER_CAPABILITIES:-}" in
    "")  CLI_CAPABILITIES="utility";;
    all) CLI_CAPABILITIES="compute compat32 display graphics utility video";;
    *)   CLI_CAPABILITIES="${NVIDIA_DRIVER_CAPABILITIES//,/ }";;
esac

# Arguments placed before the "configure" sub-command.
global_args=()
if [[ -n "${CLI_DEBUG}" ]]; then
    echo "INFO: Writing nvidia-container-cli log at ${CLI_DEBUG}." >&2
    global_args+=("--debug=${CLI_DEBUG}")
fi
if [[ "${CLI_LOAD_KMODS}" == "true" ]]; then
    global_args+=("--load-kmods")
fi
if [[ "${USERNS}" == "yes" ]]; then
    global_args+=("--user")
fi
if [[ -n "${CLI_LDCACHE}" ]]; then
    global_args+=("--ldcache=${CLI_LDCACHE}")
fi
if [[ -n "${CLI_ROOT}" ]]; then
    global_args+=("--root=${CLI_ROOT}")
fi

# Arguments placed after the "configure" sub-command.
configure_args=()
if [[ "${USERNS}" == "yes" ]]; then
    configure_args+=("--no-cgroups")
fi
if [[ -n "${CLI_LDCONFIG}" ]]; then
    configure_args+=("--ldconfig=${CLI_LDCONFIG}")
fi
# The value "none" (like an empty value) selects no device at all.
if [[ -n "${CLI_DEVICES}" && "${CLI_DEVICES}" != "none" ]]; then
    configure_args+=("--device=${CLI_DEVICES}")
fi

# Add one flag per requested driver capability. The list is expanded
# unquoted on purpose: it is a space-separated list of names.
for cap in ${CLI_CAPABILITIES}; do
    if ! arg=$(capability_to_cli "${cap}"); then
        echo "ERROR: Unknown driver capability \"${cap}\"." >&2
        exit 1
    fi
    configure_args+=("${arg}")
done

# https://github.com/nvidia/nvidia-container-runtime#nvidia_require_
# Pass the value of every NVIDIA_REQUIRE_* variable as a --require argument,
# unless the constraints were disabled.
if [ "${CLI_DISABLE_REQUIRE}" = "false" ]; then
    # "compgen -v" lists all shell variables with the prefix, exported or
    # not. "compgen -e" would only list exported ones and therefore miss the
    # NVIDIA_REQUIRE_CUDA value that this script assigns, without exporting
    # it, for legacy CUDA images.
    for req in $(compgen -v "NVIDIA_REQUIRE_"); do
        # ${!req} is the value of the variable whose name is stored in req.
        configure_args+=("--require=${!req}")
    done
fi

if [ -d "/sys/kernel/security/apparmor" ]; then
    # Try to transition to the unconfined AppArmor profile.
    # Best effort: a failed write is deliberately ignored.
    echo "changeprofile unconfined" > /proc/self/attr/current || true
fi

# Trace the final command line, then replace the shell with
# nvidia-container-cli. Both arrays are expanded quoted so that every element
# stays a single argument, e.g. a --debug or --root path containing spaces.
set -x
exec nvidia-container-cli "${global_args[@]}" configure "${configure_args[@]}" "${LXC_ROOTFS_MOUNT}"
